'''
This is pseudo-code to help you understand the paper.
The complete source code is planned to be released publicly.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.weight_norm as wn
import torch.nn.utils.spectral_norm as sn
import torch.distributions as D
from .conv import *



class TextEnc(nn.Module):
    """Convolutional text encoder that produces attention keys and values.

    Token ids are embedded and then refined by a stack of seven gated
    residual convolutions (content half * sigmoid(gate half) + skip).

    Args:
        hp: hyper-parameter object; reads ``hp.n_symbols`` and ``hp.hidden_dim``.
    """

    def __init__(self, hp):
        super().__init__()
        dim = hp.hidden_dim
        self.Embedding = nn.Embedding(hp.n_symbols, dim)
        # Every conv emits 2*dim channels: a content half and a gate half.
        self.conv_layers = nn.ModuleList(
            [Conv1d(dim, 2 * dim, 5) for _ in range(7)]
        )

    def forward(self, text):
        """Encode token ids.

        Args:
            text: integer token ids; presumably (batch, seq_len) — TODO confirm.

        Returns:
            ``(key, value)``. ``key`` is the conv-stack output transposed back
            so the feature axis is last. ``value`` is ``(key + embedding) / sqrt(2)``.
        """
        scale = 2 ** 0.5
        embedded = F.dropout(self.Embedding(text), 0.1, training=self.training)

        # Move the embedding axis to dim 1 for the convolutions.
        hidden = embedded.transpose(1, 2)
        for layer in self.conv_layers:
            content, gate = layer(hidden).chunk(2, dim=1)
            gated = (content * gate.sigmoid() + hidden) / scale
            hidden = F.dropout(gated, 0.1, training=self.training)

        key = hidden.transpose(1, 2)
        return key, (key + embedded) / scale


    
class BVAE_Layer(nn.Module):
    """VAE block with a bottom-up (inference) path and a top-down (generative) path.

    Usage contract: call ``up`` first. It caches the bottom-up posterior
    statistics on ``self`` (``qz_mean`` / ``qz_std``). Then call ``down``,
    which combines those cached statistics with its own prior prediction,
    samples a latent and returns ``(outputs, kl)``.

    Args:
        hdim: number of hidden channels.
        downsample: one of three modes.
            ``'T'``: ``up`` blur-pools the last axis by a factor of 2, and
            ``down`` repeats every step twice along it.
            ``'F'``: ``pre_conv`` takes ``2*hdim`` input channels and
            ``post_conv`` emits ``2*hdim`` channels.
            Any other value keeps shapes unchanged.
        dilation: forwarded to ``pre_conv`` and ``post_conv``.
    """

    def __init__(self, hdim, downsample=None, dilation=1.0):
        super(BVAE_Layer, self).__init__()
        self.downsample = downsample

        ####################### BOTTOM_UP #########################
        if downsample == 'F':
            self.pre_conv = Conv1d(2*hdim, hdim, 5, activation=F.elu, dilation=dilation)
        else:
            self.pre_conv = Conv1d(hdim, hdim, 5, activation=F.elu, dilation=dilation)

        # Second conv predicts (qz_mean, raw qz_std, h) stacked on the channel axis.
        self.up_conv_a = nn.ModuleList([sn(Conv1d(hdim, hdim, 5, activation=F.elu)),
                                        sn(Conv1d(hdim, 3*hdim, 5, bias=False))])
        self.up_conv_b = sn(Conv1d(hdim, hdim, 5, activation=F.elu))

        ######################## TOP_DOWN ##########################
        # Second conv predicts (pz_mean, raw pz_std, rz_mean, raw rz_std, h).
        self.down_conv_a = nn.ModuleList([sn(Conv1d(hdim, hdim, 5, activation=F.elu)),
                                          sn(Conv1d(hdim, 5*hdim, 5, bias=False))])
        # First conv consumes the concatenation of the sampled z and h (2*hdim channels).
        self.down_conv_b = nn.ModuleList([sn(Conv1d(2*hdim, hdim, 5, bias=False)),
                                          sn(Conv1d(hdim, hdim, 5, activation=F.elu))])

        if downsample == 'F':
            self.post_conv = Conv1d(hdim, 2*hdim, 5, activation=F.elu, dilation=dilation)
        else:
            self.post_conv = Conv1d(hdim, hdim, 5, activation=F.elu, dilation=dilation)

    def up(self, inputs):
        """Run the bottom-up pass.

        Side effect: sets ``self.qz_mean`` and ``self.qz_std`` (softplus, so
        positive). The next call to ``down`` consumes them.

        Returns:
            Residual features ``(pre_conv(inputs) + h) / sqrt(2)``.
        """
        if self.downsample == 'T':
            inputs = self.blur_pool(inputs)
        inputs = self.pre_conv(inputs)

        x = self.up_conv_a[0](inputs)
        self.qz_mean, self.qz_std, h = self.up_conv_a[1](x).chunk(3, 1)
        self.qz_std = F.softplus(self.qz_std)
        h = self.up_conv_b(h)

        return (inputs + h) / 2**0.5

    def down(self, inputs, temperature=1.0):
        """Run the top-down pass, sample z from the posterior and return ``(outputs, kl)``.

        Requires ``up`` to have been called first, because it reads
        ``self.qz_mean`` and ``self.qz_std``.

        Args:
            inputs: top-down features with ``hdim`` channels.
            temperature: accepted but not used by this method; the posterior
                sample is never temperature-scaled here.

        Returns:
            outputs: ``post_conv`` features. With ``downsample == 'T'`` the last
                axis is doubled; with ``'F'`` there are ``2*hdim`` channels.
            kl: mean over all elements of KL(posterior || prior).
        """
        x = self.down_conv_a[0](inputs)
        pz_mean, pz_std, rz_mean, rz_std, h = self.down_conv_a[1](x).chunk(5, 1)
        pz_std, rz_std = F.softplus(pz_std), F.softplus(rz_std)

        # The prior comes from the top-down features alone. The posterior shifts the
        # prior mean by the cached bottom-up mean plus a residual term, and
        # scales the prior std multiplicatively.
        prior = D.Normal(pz_mean, pz_std)
        posterior = D.Normal(pz_mean+self.qz_mean+rz_mean, pz_std*self.qz_std*rz_std)
        z = posterior.rsample()  # reparameterised, differentiable sample
        kl = D.kl.kl_divergence(posterior, prior).mean()

        h = torch.cat((z, h), 1)
        h = self.down_conv_b[0](h)
        h = self.down_conv_b[1](h)

        outputs = self.post_conv((inputs + h) / 2**0.5)
        if self.downsample == 'T':
            # Upsample by repeating each step twice. Note: for an odd length T,
            # blur_pool yields ceil(T/2) steps, so this returns T+1 steps; callers
            # presumably feed even lengths — TODO confirm.
            outputs = outputs.repeat_interleave(2, -1)
        return outputs, kl

    def blur_pool(self, x):
        """Apply a per-channel [1, 2, 1]/4 blur, then stride-2 subsampling along the last axis.

        The kernel is built in ``x``'s dtype and on ``x``'s device. Otherwise,
        half- or double-precision inputs would hit a dtype mismatch inside
        ``F.conv1d``.
        """
        channels = x.size(1)
        blur_kernel = torch.tensor([1.0, 2.0, 1.0], dtype=x.dtype, device=x.device) / 4.0
        # groups=channels applies an independent copy of the kernel to each channel.
        blur_kernel = blur_kernel.view(1, 1, 3).repeat(channels, 1, 1)
        return F.conv1d(x, blur_kernel, padding=1, stride=2, groups=channels)
    
    
    
class TopDown(nn.Module):
    """Top-down-only VAE block that samples its latent from the learned prior.

    Its submodules are built exactly like the top-down part of ``BVAE_Layer``.
    The second ``down_conv_a`` conv still emits ``5*hdim`` channels, although
    only the prior mean/std and the feature chunk are used here. Presumably
    this keeps parameter shapes identical so weights can be shared — TODO confirm.

    Args:
        hdim: number of hidden channels.
        downsample: ``'T'`` makes ``down`` repeat every step twice along the
            last axis; ``'F'`` makes ``post_conv`` emit ``2*hdim`` channels.
        dilation: forwarded to ``post_conv``.
    """

    def __init__(self, hdim, downsample=None, dilation=1.0):
        super().__init__()
        self.downsample = downsample

        # ---------------------------- top-down path ----------------------------
        self.down_conv_a = nn.ModuleList([
            sn(Conv1d(hdim, hdim, 5, activation=F.elu)),
            sn(Conv1d(hdim, 5 * hdim, 5, bias=False)),
        ])
        self.down_conv_b = nn.ModuleList([
            sn(Conv1d(2 * hdim, hdim, 5, bias=False)),
            sn(Conv1d(hdim, hdim, 5, activation=F.elu)),
        ])
        out_channels = 2 * hdim if downsample == 'F' else hdim
        self.post_conv = Conv1d(hdim, out_channels, 5, activation=F.elu, dilation=dilation)

    def down(self, inputs, temperature=1.0):
        """Sample z from the prior (std scaled by ``temperature``) and decode.

        Returns:
            ``post_conv((inputs + h) / sqrt(2))``. The last axis is doubled when
            ``downsample == 'T'``.
        """
        features = self.down_conv_a[0](inputs)
        prior_mean, prior_std, _, _, h = self.down_conv_a[1](features).chunk(5, 1)
        prior_std = F.softplus(prior_std)
        z = D.Normal(prior_mean, prior_std * temperature).rsample()

        h = self.down_conv_b[1](self.down_conv_b[0](torch.cat((z, h), 1)))
        outputs = self.post_conv((inputs + h) / 2**0.5)
        if self.downsample == 'T':
            outputs = outputs.repeat_interleave(2, -1)
        return outputs
    

    
class DurationPredictor(nn.Module):
    """Predict one positive value per time step from hidden features.

    Two conv -> LayerNorm -> dropout stages are followed by a linear head.
    The output is ``exp(linear(.)) + 1``, so every prediction is >= 1.

    Args:
        hp: hyper-parameter object; reads ``hp.hidden_dim``.
    """

    def __init__(self, hp):
        super().__init__()
        dim = hp.hidden_dim
        self.conv1 = Conv1d(dim, dim, 3, bias=False, activation=F.elu)
        self.conv2 = Conv1d(dim, dim, 3, bias=False, activation=F.elu)

        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(0.1)

        self.linear = Linear(dim, 1)

    def forward(self, h):
        """Map channel-first features ``h`` to per-step predictions.

        ``h`` is transposed to channel-last before each LayerNorm, so LayerNorm
        normalises over the ``hidden_dim`` axis. Presumably ``h`` is
        (batch, hidden_dim, time) — TODO confirm.

        Returns:
            The predictions with the trailing singleton axis squeezed away.
        """
        # Stage 1: conv, then normalise over channels (channel-last result).
        y = self.dropout(self.ln1(self.conv1(h).transpose(1, 2)))
        # Stage 2: back to channel-first for the conv, then the same normalisation.
        y = self.dropout(self.ln2(self.conv2(y.transpose(1, 2)).transpose(1, 2)))
        raw = self.linear(y)
        return (torch.exp(raw) + 1).squeeze(-1)

    

class Prenet(nn.Module):
    """Two 1x1 convolutions, each followed by dropout with p=0.5.

    The dropout is hard-coded to ``training=True``, so it also fires when the
    module is in eval mode.

    Args:
        hp: hyper-parameter object; reads ``hp.n_mel_channels`` and ``hp.hidden_dim``.
    """

    def __init__(self, hp):
        super().__init__()
        first = Conv1d(hp.n_mel_channels, hp.hidden_dim, 1, bias=True, activation=F.elu)
        second = Conv1d(hp.hidden_dim, hp.hidden_dim, 1, bias=True, activation=F.elu)
        self.layers = nn.ModuleList([first, second])

    def forward(self, x):
        out = x
        for layer in self.layers:
            # training=True on purpose in the original: dropout is never disabled.
            out = F.dropout(layer(out), 0.5, training=True)
        return out


    
class Projection(nn.Module):
    """Three-layer convolutional head whose output is squashed into (0, 1) by a sigmoid.

    Args:
        hdim: input/hidden channel count.
        outdim: output channel count.
    """

    def __init__(self, hdim, outdim):
        super().__init__()
        # ``p=0.5`` is passed to the project Conv1d; presumably a dropout
        # probability — TODO confirm against .conv.
        hidden = [Conv1d(hdim, hdim, 5, p=0.5, activation=F.elu) for _ in range(2)]
        self.layers = nn.ModuleList(hidden + [Conv1d(hdim, outdim, 5)])

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out.sigmoid()
    